# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Literal, Optional

import numpy as np
import nvtx
import torch

from physicsnemo.models.diffusion import EDMPrecond
from physicsnemo.utils.patching import GridPatching2D

# ruff: noqa: E731


# NOTE: use two wrappers for apply, to avoid recompilation when input shape changes
@torch.compile()
def _apply_wrapper_Cin_channels(patching, input, additional_input=None):
    """
    Apply the patching operation to the input tensor with :math:`C_{in}` channels.
    """
    return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _apply_wrapper_Cout_channels_no_grad(patching, input, additional_input=None):
    """
    Apply the patching operation to an input tensor with :math:`C_{out}`
    channels that does not require gradients.
    """
    return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _apply_wrapper_Cout_channels_grad(patching, input, additional_input=None):
    """
    Apply the patching operation to an input tensor with :math:`C_{out}`
    channels that requires gradients.
    """
    return patching.apply(input=input, additional_input=additional_input)


@torch.compile()
def _fuse_wrapper(patching, input, batch_size):
    return patching.fuse(input=input, batch_size=batch_size)


def _apply_wrapper_select(
    input: torch.Tensor, patching: GridPatching2D | None
) -> Callable:
    """
    Select the correct patching wrapper based on the input tensor's requires_grad attribute.
    If patching is None, return the identity function.
    If patching is not None, return the appropriate patching wrapper.
    If input.requires_grad is True, return _apply_wrapper_Cout_channels_grad.
    If input.requires_grad is False, return
    _apply_wrapper_Cout_channels_no_grad.
    """
    if patching:
        if input.requires_grad:
            return _apply_wrapper_Cout_channels_grad
        else:
            return _apply_wrapper_Cout_channels_no_grad
    else:
        return lambda patching, input, additional_input=None: input


@nvtx.annotate(message="deterministic_sampler", color="red")
def deterministic_sampler(
    net: torch.nn.Module,
    latents: torch.Tensor,
    img_lr: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    randn_like: Callable = torch.randn_like,
    patching: Optional[GridPatching2D] = None,
    mean_hr: Optional[torch.Tensor] = None,
    lead_time_label: Optional[torch.Tensor] = None,
    num_steps: int = 18,
    sigma_min: Optional[float] = None,
    sigma_max: Optional[float] = None,
    rho: float = 7.0,
    solver: Literal["heun", "euler"] = "heun",
    discretization: Literal["vp", "ve", "iddpm", "edm"] = "edm",
    schedule: Literal["vp", "ve", "linear"] = "linear",
    scaling: Literal["vp", "none"] = "none",
    epsilon_s: float = 1e-3,
    C_1: float = 0.001,
    C_2: float = 0.008,
    M: int = 1000,
    alpha: float = 1.0,
    S_churn: int = 0,
    S_min: float = 0.0,
    S_max: float = float("inf"),
    S_noise: float = 1.0,
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    r"""
    Generalized sampler, representing the superset of all sampling methods
    discussed in the paper `Elucidating the Design Space of Diffusion-Based
    Generative Models (EDM) <https://arxiv.org/abs/2206.00364>`_.

    This function integrates an ODE (probability flow) or SDE over multiple
    time-steps to generate samples from the diffusion model provided by the
    argument 'net'. It can be used to combine multiple choices to
    design a custom sampler, including multiple integration solver,
    discretization method, noise schedule, and so on.

    Parameters
    ----------
    net : torch.nn.Module
        The diffusion model to use in the sampling process.
    latents : torch.Tensor
        The latent random noise used as the initial condition for the
        stochastic ODE.
    img_lr : torch.Tensor
        Low-resolution input image for conditioning the diffusion process.
        Passed as a keywork argument to the model ``net``.
    class_labels : Optional[torch.Tensor]
        Labels of the classes used as input to a class-conditionned
        diffusion model. Passed as a keyword argument to the model ``net``.
        If provided, it must be a tensor containing  integer values.
        Defaults to ``None``, in which case it is ignored.
    randn_like: Callable
        Random Number Generator to generate random noise that is added
        during the stochastic sampling. Must have the same signature as
        torch.randn_like and return torch.Tensor. Defaults to
        torch.randn_like.
    patching : Optional[GridPatching2D], default=None
        A patching utility for patch-based diffusion. Implements methods to
        extract patches from an image and batch the patches along dim=0.
        Should also implement a ``fuse`` method to reconstruct the original
        image from a batch of patches. See
        :class:`~physicsnemo.utils.patching.GridPatching2D` for details. By
        default ``None``, in which case non-patched diffusion is used.
    mean_hr : Optional[Tensor], optional
        Optional tensor containing mean high-resolution images for
        conditioning. Must have same height and width as ``img_lr``, with shape
        :math:`(B_{hr}, C_{hr}, H, W)`  where the batch dimension
        :math:`B_{hr}` can be either 1, either equal to ``batch_size``, or can be omitted. If
        :math:`B_{hr} = 1` or is omitted, ``mean_hr`` will be expanded to match the shape
        of ``img_lr``. By default ``None``.
    lead_time_label : Optional[Tensor], optional
        Lead-time labels to pass to the model, shape ``(batch_size,)``.
        If not provided, the model is called without a lead-time label input.
    num_steps : Optional[int]
        Number of time-steps for the stochastic ODE integration. Defaults
        to 18.
    sigma_min : Optional[float]
        Minimum noise level for the diffusion process. ``sigma_min``,
        ``sigma_max``, and ``rho`` are used to compute the time-step
        discretization, based on the choice of discretization. For the
        default choice (``discretization='heun'``), the noise level schedule
        is computed as:
        :math:`\sigma_i = (\sigma_{max}^{1/\rho} + i / (\text{num_steps} - 1) * (\sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho}))^{\rho}`.
        For other choices of ``discretization``, see details in the EDM
        paper. Defaults to ``None``, in which case defaults values depending
        of the specified discretization are used.
    sigma_max : Optional[float]
        Maximum noise level for the diffusion process. See ``sigma_min`` for
        details. Defaults to ``None``, in which case defaults values depending
        of the specified discretization are used.
    rho : float, optional
        Exponent used in the noise schedule. See ``sigma_min`` for details.
        Only used when ``discretization="heun"``. Values in the range
        [5, 10] produce better images. Lower values lead to truncation errors
        equalized over all time steps. Defaults to 7.
    solver : Literal["heun", "euler"]
        The numerical method used to integrate the stochastic ODE. ``"euler"``
        is 1st order solver, which is faster but produces lower-quality
        images. ``"heun"`` is 2nd order, more expensive, but produces
        higher-quality images. Defaults to ``"heun"``.
    discretization : Literal["vp", "ve", "iddpm", "edm"]
        The method to discretize time-steps :math:`t_i` in the
        diffusion process. See the EDM paper for details. Defaults to
        ``"edm"``.
    schedule : Literal["vp", "ve", "linear"]
        The type of noise level schedule.  Defaults to ``"linear"``. If
        ``schedule="ve"``, then :math:`\sigma(t) = \sqrt{t}`. If
        ``schedule="linear"``, then :math:`\sigma(t) = t`. If ``schedule="vp"``,
        see EDM paper for details. Defaults to ``"linear"``.
    scaling : Literal["vp", "none"]
        The type of time-dependent signal scaling :math:`s(t)`, such that
        :math:`x = s(t) \hat{x}`. See EDM paper for details on the ``"vp"``
        scaling. Defaults to ``"none"``, in which case :math:`s(t)=1`.
    epsilon_s : float, optional
        Parameter to compute both the noise level schedule and the
        time-step discetization. Only used when ``discretization="vp"`` or
        ``schedule="vp"``. Ignored in other cases. Defaults to 1e-3.
    C_1 : float, optional
        Parameters to compute the time-step discetization. Only used when
        ``discretization="iddpm"``. Defaults to 0.001.
    C_2 : float, optional
        Same as for C_1. Only used when ``discretization="iddpm"``. Defaults to
        0.008.
    M : int, optional
        Same as for C_1 and C_2. Only used when ``discretization="iddpm"``.
        Defaults to 1000.
    alpha : float, optional
        Controls (i.e. multiplies) the step size :math:`t_{i+1} -
        \hat{t}_i` in the stochastic sampler, where :math:`\hat{t}_i` is
        the temporarily increased noise level. Defaults to 1.0, which is
        the recommended value.
    S_churn : int, optional
        Controls the amount of stochasticty injected in the SDE in the
        stochatsic sampler. Larger values of ``S_churn`` lead to larger values
        of :math:`\hat{t}_i`, which in turn lead to injecting more
        stochasticity in the SDE by  Defaults to 0, which means no
        stochasticity is injected.
    S_min : float, optional
        ``S_min`` and ``S_max`` control the time-step range over which
        stochasticty is injected in the SDE. Stochasticity is injected
        through :math:`\hat{t}_i` for time-steps :math:`t_i` such that
        :math:`S_{min} \leq t_i \leq S_{max}`. Defaults to 0.0.
    S_max : float, optional
        See ``S_min``. Defaults to ``float("inf")``.
    S_noise : float, optional
        Controls the amount of stochasticty injected in the SDE in the
        stochatsic sampler. Added signal noise is proportinal to
        :math:`\epsilon_i` where :math:`\epsilon_i \sim \mathcal{N}(0, S_{noise}^2)`. Defaults
        to 1.0.
    dtype : torch.dtype, optional
        Controls the precision used for sampling
    Returns
    -------
        torch.Tensor:
            Generated batch of samples. Same shape as the input ``latents``.
    """

    # conditioning = [mean_hr, img_lr, global_lr]
    x_lr = img_lr
    if mean_hr is not None:
        if mean_hr.shape[-2:] != img_lr.shape[-2:]:
            raise ValueError(
                f"mean_hr and img_lr must have the same height and width, "
                f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}."
            )
        x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1)

    # Safety check on type of patching
    if patching is not None and not isinstance(patching, GridPatching2D):
        raise ValueError("patching must be an instance of GridPatching2D.")

    # Safety check: if patching is used then img_lr and latents must have same
    # height and width, otherwise there is mismatch in the number
    # of patches extracted to form the final batch_size.
    if patching:
        if img_lr.shape[-2:] != latents.shape[-2:]:
            raise ValueError(
                f"img_lr and latents must have the same height and width, "
                f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. "
            )
    # img_lr and latents must also have the same batch_size, otherwise mismatch
    # when processed by the network
    if img_lr.shape[0] != latents.shape[0]:
        raise ValueError(
            f"img_lr and latents must have the same batch size, but found "
            f"{img_lr.shape[0]} vs {latents.shape[0]}."
        )

    if solver not in ["euler", "heun"]:
        raise ValueError(f"Unknown solver {solver}")
    if discretization not in ["vp", "ve", "iddpm", "edm"]:
        raise ValueError(f"Unknown discretization {discretization}")
    if schedule not in ["vp", "ve", "linear"]:
        raise ValueError(f"Unknown schedule {schedule}")
    if scaling not in ["vp", "none"]:
        raise ValueError(f"Unknown scaling {scaling}")

    # Helper functions for VP & VE noise level schedules.
    vp_sigma = (
        lambda beta_d, beta_min: lambda t: (
            np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1
        )
        ** 0.5
    )
    vp_sigma_deriv = (
        lambda beta_d, beta_min: lambda t: 0.5
        * (beta_min + beta_d * t)
        * (sigma(t) + 1 / sigma(t))
    )
    vp_sigma_inv = (
        lambda beta_d, beta_min: lambda sigma: (
            (beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min
        )
        / beta_d
    )
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma**2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
        sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[
            discretization
        ]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
        sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    batch_size = img_lr.shape[0]
    # input and position padding + patching
    if patching:
        # Patched conditioning [x_lr, mean_hr]
        # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x)
        x_lr = _apply_wrapper_Cin_channels(
            patching=patching, input=x_lr, additional_input=img_lr
        )

        # Function to select the correct positional embedding for each patch
        def patch_embedding_selector(emb):
            # emb: (N_pe, image_shape_y, image_shape_x)
            # return: (batch_size * patch_num, N_pe, patch_shape_y, patch_shape_x)
            return patching.apply(emb.expand(batch_size, -1, -1, -1))

    else:
        patch_embedding_selector = None

    # Compute corresponding betas for VP.
    vp_beta_d = (
        2
        * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1))
        / (epsilon_s - 1)
    )
    vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=dtype, device=latents.device)
    if discretization == "vp":
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == "ve":
        orig_t_steps = (sigma_max**2) * (
            (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1))
        )
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == "iddpm":
        u = torch.zeros(M + 1, dtype=dtype, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device):  # M, ..., 1
            u[j - 1] = (
                (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1
            ).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[
            ((len(u_filtered) - 1) / (num_steps - 1) * step_indices)
            .round()
            .to(torch.int64)
        ]
    else:
        sigma_steps = (
            sigma_max ** (1 / rho)
            + step_indices
            / (num_steps - 1)
            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho

    # Define noise level schedule.
    if schedule == "vp":
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == "ve":
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == "vp":
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(dtype) * (sigma(t_next) * s(t_next))

    optional_args = {}
    if lead_time_label is not None:
        optional_args["lead_time_label"] = lead_time_label
    if patching:
        optional_args["embedding_selector"] = patch_embedding_selector

    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = (
            min(S_churn / num_steps, np.sqrt(2) - 1)
            if S_min <= sigma(t_cur) <= S_max
            else 0
        )
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (
            sigma(t_hat) ** 2 - sigma(t_cur) ** 2
        ).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

        # Euler step. Perform patching operation on score tensor if patch-based
        # generation is used denoised = net(x_hat, t_hat,
        # class_labels,lead_time_label=lead_time_label)

        h = t_next - t_hat
        x_hat_batch = _apply_wrapper_select(input=x_hat, patching=patching)(
            patching=patching, input=x_hat
        ).to(latents.device)

        if isinstance(net, EDMPrecond):
            # Conditioning info is passed as keyword arg
            denoised = net(
                x_hat_batch / s(t_hat),
                sigma(t_hat),
                condition=x_lr,
                class_labels=class_labels,
                **optional_args,
            ).to(dtype)
        else:
            denoised = net(
                x_hat_batch / s(t_hat),
                x_lr,
                sigma(t_hat),
                class_labels,
                **optional_args,
            ).to(dtype)

        if patching:
            # Un-patch the denoised image
            # (batch_size, C_out, img_shape_y, img_shape_x)
            denoised = _fuse_wrapper(
                patching=patching, input=denoised, batch_size=batch_size
            )

        d_cur = (
            sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)
        ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction.
        if solver == "euler" or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            # Patched input
            # (batch_size * patch_num, C_out, patch_shape_y, patch_shape_x)
            x_prime_batch = _apply_wrapper_select(input=x_prime, patching=patching)(
                patching=patching, input=x_prime
            ).to(latents.device)

            if isinstance(net, EDMPrecond):
                # Conditioning info is passed as keyword arg
                denoised = net(
                    x_prime_batch / s(t_prime),
                    sigma(t_prime),
                    condition=x_lr,
                    class_labels=class_labels,
                    **optional_args,
                ).to(dtype)
            else:
                denoised = net(
                    x_prime_batch / s(t_prime),
                    x_lr,
                    sigma(t_prime),
                    class_labels,
                    **optional_args,
                ).to(dtype)

            if patching:
                # Un-patch the denoised image
                # (batch_size, C_out, img_shape_y, img_shape_x)
                denoised = _fuse_wrapper(
                    patching=patching, input=denoised, batch_size=batch_size
                )

            d_prime = (
                sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)
            ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * (
                (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime
            )

    return x_next
